import json
import os
from langchain import PromptTemplate, LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.schema import BaseOutputParser


# Prompt for the first LLM call: the model is shown one problem (``{problem}``)
# and asked to pick a drawing tool and answer with a JSON object.
# ``{{`` / ``}}`` are literal braces escaped for ``str.format``-style templating.
# The ``\\n`` sequences in the example are deliberate: they render as the two
# characters backslash + n, so the example shown to the model is valid JSON
# (a raw line break inside a JSON string is rejected by ``json.loads``).
prompt_template = '''You are presented with an abstract reasoning problem represented as following:

{problem}

Your task is to imagine and visualize the problem using a drawing tool. You have the option to choose from the following tools:

- Python matplotlib: This tool allows you to create visualizations and drawings using Python code. It is suitable for solving abstract problems that can be represented graphically. Matplotlib is suitable for creating line charts, bar charts, pie charts, and other types of graphs.
- Python turtle: Python turtle graphics provides a way to create drawings and patterns using a turtle that moves around the screen. It is suitable for solving problems that involve geometric shapes and patterns. Turtle is suitable for drawing various geometric shapes.
- Stable diffusion: Stable diffusion is a model that can generate images from text prompt. The length of text prompt is less than 100. It is suitable for text-to-image tasks and can be used to visualize abstract problems. The stable diffusion model is suitable for creating imaginative visuals rather than standardized charts. 

Please choose your tool according to the property of given problem.

Your output should be in JSON format directly and contain the following fields:

require_imagine: Indicates whether the problem requires imagination.
tool: Specifies the drawing tool you choose.
code: (Optional)If you use matplotlib or turtle, you should provide any code that draw figure with your chosen tool, the code should save to 'figure.png'. 
prompt: (Optional)If you are using the stable diffusion, the 'code' field will be changed to the 'prompt' field to indicate the textual prompts you provide to the text-to-image model.

Example Prompt:
{{
"require_imagine": true,
"tool": "Python matplotlib",
"code": "import matplotlib.pyplot as plt\\nfig, ax = plt.subplots()\\n..."
}}'''

# Prompt for the second LLM call when a code-based tool was chosen: the raw
# first-stage reply is substituted for ``{result}`` and the model is asked to
# return only the code. The backslash before OUTPUT is part of the prompt text
# and is written as ``\\`` so it is no longer an invalid escape sequence.
extract_code_prompt = '''Please extract the code (remove the `plt.show()` or any other iteractive command) within the following content:
{result}
\\OUTPUT DIRECTLY WITHOUT ANY OTHER INFORMATION!'''

# Prompt for the second LLM call when the text-to-image tool was chosen: the
# raw first-stage reply is substituted for ``{result}`` and the model is asked
# to return only the image prompt. The backslash before OUTPUT is part of the
# prompt text and is written as ``\\`` so it is no longer an invalid escape
# sequence.
extract_prompt_prompt = '''Please extract content of prompt within the following content:
{result}
\\OUTPUT DIRECTLY WITHOUT ANY OTHER INFORMATION!'''


class SetupOutputParser(BaseOutputParser):
    """Pull the JSON object out of a free-form LLM reply."""

    def parse(self, text: str):
        """Parse the output of an LLM call.

        The JSON payload is taken to span from the first ``{`` to the *last*
        ``}`` in ``text``. Using the last closing brace keeps nested objects,
        and string values that themselves contain ``}``, inside the slice.

        Args:
            text: Raw model output, possibly with prose around the JSON.

        Returns:
            The decoded JSON value (normally a ``dict``), or ``text``
            unchanged when it contains no ``{...}`` span at all.

        Raises:
            json.JSONDecodeError: If a brace-delimited span exists but is not
                valid JSON.
        """
        start = text.find('{')
        end = text.rfind('}')
        # No opening brace, or no closing brace after it: nothing to decode.
        if start == -1 or end < start:
            return text

        return json.loads(text[start:end + 1])


# Driver script: for every problem of every selected task, ask the chat model
# to choose a drawing tool, then run a second call that extracts either the
# drawing code or the text-to-image prompt, and append the result to
# temp/<task>.txt.

# Chat model name passed to ChatOpenAI for both calls.
model = 'gpt-3.5-turbo'

# Tasks to run; each one names a sub-directory of ../dataset_small.
# The full list previously used here was:
# ['creative_writing', 'intersect_geometry', 'sudoku', 'time_series_prediction']
tasks = ['time_series_prediction']

# The client, the chains and the parser do not depend on the task or the
# item, so they are built once instead of once per problem.
llm = ChatOpenAI(
    model_name=model,
    temperature=0
)
setup_chain = LLMChain(
    llm=llm,
    prompt=PromptTemplate.from_template(prompt_template)
)
extract_code_chain = LLMChain(
    llm=llm,
    prompt=PromptTemplate.from_template(extract_code_prompt)
)
extract_prompt_chain = LLMChain(
    llm=llm,
    prompt=PromptTemplate.from_template(extract_prompt_prompt)
)
output_parser = SetupOutputParser()

# Make sure the output directory exists before the first append.
output_dir = 'temp'
os.makedirs(output_dir, exist_ok=True)

for task in tasks:
    task_metadata = os.path.join('..', 'dataset_small', task, 'task.json')
    with open(task_metadata, 'r', encoding='utf8') as f:
        data = json.load(f)

    output_path = os.path.join(output_dir, task + '.txt')
    for item in data:
        problem = item['input']

        raw_output = setup_chain.predict(problem=problem)
        print(raw_output)

        # Line breaks are stripped only for this parse: the parsed object is
        # used solely to read the chosen tool, while the extraction call
        # below still receives the untouched raw output.
        try:
            json_result = output_parser.parse(raw_output.replace('\n', ''))
        except json.JSONDecodeError as err:
            print('=' * 20)
            print(f'Skipping item: model output is not valid JSON ({err})')
            continue
        print('=' * 20)
        print(json_result)

        # The parser hands back the plain text when it finds no JSON object,
        # and the model may omit "tool" or give it a non-string value.
        tool = json_result.get('tool') if isinstance(json_result, dict) else None
        if not isinstance(tool, str):
            print('Skipping item: no usable "tool" field in model output')
            continue

        is_text_to_img_model = 'stable diffusion' in tool.lower()
        extract_chain = extract_prompt_chain if is_text_to_img_model else extract_code_chain
        code_or_prompt = extract_chain.predict(result=raw_output)
        print('=' * 20)
        print(code_or_prompt)

        with open(output_path, 'a', encoding='utf8') as f:
            f.write('=' * 40 + '\n')
            f.write(problem + '\n')
            f.write('-' * 40 + '\n')
            f.write(code_or_prompt + '\n\n')
